{ "cells": [ { "cell_type": "markdown", "id": "yfiI-lSZRuOP", "metadata": { "id": "yfiI-lSZRuOP" }, "source": [ "### **1. S-learner**\n", "\n", "\n", "The first estimator we introduce is the S-learner, also known as the \"single learner\". It is one of the most fundamental learners in HTE estimation and is very easy to implement.\n", "\n", "Under three common assumptions in causal inference, namely (1) consistency, (2) no unmeasured confounders (NUC), and (3) positivity, the heterogeneous treatment effect (HTE) can be identified from the observed data as\n", "\\begin{equation*}\n", "\\tau(s)=\\mathbb{E}[R|S=s,A=1]-\\mathbb{E}[R|S=s,A=0].\n", "\\end{equation*}\n", "\n", "The basic idea of the S-learner is to fit a single model for $\\mathbb{E}[R|S,A]$, treating the action $A$ as one more input feature, and then construct a plug-in estimator of $\\tau(s)$ from it. Specifically, the algorithm can be summarized as follows:\n", "\n", "**Step 1:** Estimate the response function $\\mu(s,a):=\\mathbb{E}[R|S=s,A=a]$ with any supervised machine learning algorithm;\n", "\n", "**Step 2:** The S-learner estimate of the HTE is given by\n", "\\begin{equation*}\n", "\\hat{\\tau}_{\\text{S-learner}}(s)=\\hat\\mu(s,1)-\\hat\\mu(s,0).\n", "\\end{equation*}\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "eRpP5k9MBtzO", "metadata": { "ExecuteTime": { "end_time": "2023-11-12T12:59:55.276902Z", "start_time": "2023-11-12T12:59:54.456667Z" }, "id": "eRpP5k9MBtzO" }, "outputs": [], "source": [ "# import related packages\n", "import numpy as np\n", "import pandas as pd\n", "from matplotlib import pyplot as plt\n", "from sklearn.ensemble import GradientBoostingRegressor\n", "from sklearn.linear_model import LinearRegression\n", "from causaldm.learners.CEL.Single_Stage import _env_getdata_CEL" ] }, { "cell_type": "markdown", "id": "XUu695Qrf61-", "metadata": { "id": "XUu695Qrf61-" }, "source": [ "### MovieLens Data" ] }, { "cell_type": "code", "execution_count": 2, "id": "JhfJntzcVVy2", "metadata": { "ExecuteTime": { "end_time": 
"2023-11-12T12:59:55.353538Z", "start_time": "2023-11-12T12:59:55.278183Z" }, "colab": { "base_uri": "https://localhost:8080/", "height": 424 }, "executionInfo": { "elapsed": 288, "status": "ok", "timestamp": 1676750101543, "user": { "displayName": "Yang Xu", "userId": "12270366590264264299" }, "user_tz": 300 }, "id": "JhfJntzcVVy2", "outputId": "7fab8a7a-7cd9-445c-a005-9a6d1994a071" }, "outputs": [ { "data": { "text/html": [ "
|  | user_id | movie_id | rating | age | Drama | gender_M | occupation_academic/educator | occupation_college/grad student | occupation_executive/managerial | occupation_other | occupation_technician/engineer |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 48.0 | 1193.0 | 4.0 | 25.0 | 1.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 1 | 48.0 | 919.0 | 4.0 | 25.0 | 1.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 2 | 48.0 | 527.0 | 5.0 | 25.0 | 1.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 3 | 48.0 | 1721.0 | 4.0 | 25.0 | 1.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 4 | 48.0 | 150.0 | 4.0 | 25.0 | 1.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 65637 | 5878.0 | 3300.0 | 2.0 | 25.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 65638 | 5878.0 | 1391.0 | 1.0 | 25.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 65639 | 5878.0 | 185.0 | 4.0 | 25.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 65640 | 5878.0 | 2232.0 | 1.0 | 25.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 65641 | 5878.0 | 426.0 | 3.0 | 25.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |

65642 rows × 11 columns
\n", "